package com.xunmall.security.acl.utils;

import com.google.common.collect.Lists;
import com.xunmall.base.dto.Pair;
import com.xunmall.base.util.StringUtils;
import com.xunmall.security.acl.AclInterceptor;
import com.xunmall.security.acl.annotation.ValueType;
import lombok.extern.slf4j.Slf4j;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.ExpressionVisitorAdapter;
import net.sf.jsqlparser.expression.StringValue;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.parser.CCJSqlParserManager;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.StatementVisitorAdapter;
import net.sf.jsqlparser.statement.select.*;

import java.io.StringReader;
import java.util.List;
import java.util.Set;

/**
 * @Author: WangYanjing
 * @Date: 2018/12/27 10:53
 * @Description: Parses SQL, finds the tables that are under data-permission (ACL) control and
 * automatically adds a {@code column IN (...)} filter to the WHERE or ON clause.
 *
 * <p>Which tables are controlled, and by which column, is looked up in
 * {@link AclInterceptor#getAcl()}; the permitted values come from {@code getControlValue} of the
 * {@link AclHelper} superclass.
 */
@Slf4j
public class SqlModifierHelper extends AclHelper {

    /** Set to {@code true} once at least one ACL filter has been added to the statement. */
    private boolean change = false;

    /**
     * Parses {@code originalSql} and appends ACL filters to it.
     *
     * @param originalSql the SQL to rewrite
     * @return a pair of (whether any filter was added, rewritten SQL); the SQL part is
     *     {@code null} when the statement could not be parsed
     */
    public static Pair<Boolean, String> processSql(String originalSql) {
        SqlModifierHelper modifier = new SqlModifierHelper();
        String newSql = modifier.sqlModify(originalSql);
        return new Pair<>(modifier.change, newSql);
    }

    /**
     * Parses {@code sql}, adds ACL filters to every SELECT it contains and prints it back.
     *
     * @param sql the SQL to rewrite
     * @return the rewritten SQL, or {@code null} if the statement cannot be parsed
     */
    public String sqlModify(String sql) {
        Statement stmt;
        try {
            stmt = new CCJSqlParserManager().parse(new StringReader(sql));
        } catch (JSQLParserException e) {
            // Contract is unchanged (null result, change stays false), but make the failure visible.
            log.warn("Failed to parse SQL for ACL rewrite: {}", sql, e);
            return null;
        }
        // Walk the statement with the visitor pattern; filters are added to the AST in place.
        stmt.accept(new StatementVisitor());
        return stmt.toString();
    }

    /** Entry visitor: only SELECT statements are rewritten; other statements are not visited. */
    class StatementVisitor extends StatementVisitorAdapter {
        @Override
        public void visit(Select select) {
            // CTE bodies (WITH x AS (SELECT ...)) read tables too, so they need filters as well.
            List<WithItem> withItems = select.getWithItemsList();
            if (withItems != null) {
                for (WithItem withItem : withItems) {
                    withItem.getSelectBody().accept(new SelectVisitor());
                }
            }
            select.getSelectBody().accept(new SelectVisitor());
        }
    }

    /** Adds ACL filters to each plain SELECT, recursing into set operations and subqueries. */
    class SelectVisitor extends SelectVisitorAdapter {
        @Override
        public void visit(PlainSelect ps) {
            // Expression visitor that routes every subquery it meets back into this select visitor.
            ExpressionVisitor expressionVisitor = new ExpressionVisitor();
            expressionVisitor.setSelectVisitor(this);

            // Scalar subqueries in the select list, e.g. SELECT (SELECT x FROM t) FROM ...
            List<SelectItem> selectItems = ps.getSelectItems();
            if (selectItems != null) {
                SelectItemVisitor selectItemVisitor = new SelectItemVisitor();
                selectItemVisitor.setExpressionVisitor(expressionVisitor);
                for (SelectItem item : selectItems) {
                    item.accept(selectItemVisitor);
                }
            }

            // FROM part: filters for its tables go into WHERE. There is no FROM for e.g. "SELECT 1".
            if (ps.getFromItem() != null) {
                ps.setWhere(createExpression(collectTables(ps.getFromItem()), ps.getWhere()));
            }

            // JOIN part: each joined table gets its own filter.
            List<Join> joins = ps.getJoins();
            if (joins != null) {
                for (Join join : joins) {
                    List<Pair<String, String>> joinTables = collectTables(join.getRightItem());
                    if (join.isSimple()) {
                        // Comma join ("FROM a, b"): Join.toString() prints no ON clause for simple
                        // joins, so a filter stored there would be silently dropped. Use WHERE.
                        ps.setWhere(createExpression(joinTables, ps.getWhere()));
                    } else {
                        // NOTE(review): for JOIN ... USING (...) this renders "ON ... USING (...)";
                        // confirm whether such joins need dedicated handling.
                        join.setOnExpression(createExpression(joinTables, join.getOnExpression()));
                    }
                }
            }

            // Subqueries inside WHERE (IN / EXISTS / comparisons) are rewritten recursively.
            if (ps.getWhere() != null) {
                ps.getWhere().accept(expressionVisitor);
            }

            Limit limit = ps.getLimit();
            if (limit != null) {
                ps.setLimit(new SqlLimit(limit));
            }
        }

        /**
         * Collects the (table name, alias) pairs that a FROM/JOIN item reads directly. Derived
         * tables (subqueries) inside the item are rewritten in their own scope on the way.
         */
        private List<Pair<String, String>> collectTables(FromItem fromItem) {
            List<Pair<String, String>> tables = Lists.newArrayList();
            fromItem.accept(new FromItemVisitor(tables));
            return tables;
        }

        /**
         * ANDs one {@code column IN (...)} filter onto {@code oldExp} for every ACL-controlled
         * table in {@code tables}.
         *
         * @param tables (upper-cased table name, alias or {@code null}) pairs to check
         * @param oldExp the existing condition; may be {@code null}
         * @return the combined condition, or {@code oldExp} if no table needs filtering
         */
        private Expression createExpression(List<Pair<String, String>> tables, Expression oldExp) {
            Expression result = oldExp;
            for (Pair<String, String> table : tables) {
                Pair<String, ValueType> acl = AclInterceptor.getAcl().get(table.getFst());
                if (acl == null) {
                    continue;  // table is not under data-permission control
                }
                Set<String> controlValue = getControlValue(acl.getSnd());
                if (controlValue == null || controlValue.isEmpty()) {
                    continue;  // empty value set: no filtering is applied
                }
                List<Expression> values = Lists.newArrayListWithCapacity(controlValue.size());
                for (String item : controlValue) {
                    // StringValue's constructor expects a quoted literal. The value itself is not
                    // escaped. NOTE(review): confirm control values can never contain quotes.
                    values.add(new StringValue("'" + item + "'"));
                }
                String columnName = acl.getFst();
                if (StringUtils.isNotEmpty(table.getSnd())) {
                    columnName = table.getSnd() + "." + columnName;
                }
                Expression filter = new InExpression(new Column(columnName), new ExpressionList(values));
                // Chain onto the accumulated result (not oldExp) so earlier tables' filters are kept.
                result = result == null ? filter : new AndExpression(result, filter);
                change = true;
            }
            return result;
        }

        @Override
        public void visit(SetOperationList setOpList) {
            // UNION / INTERSECT / ...: rewrite every branch independently.
            List<SelectBody> selects = setOpList.getSelects();
            for (SelectBody select : selects) {
                select.accept(this);
            }
        }
    }

    /**
     * Records the physical tables of a FROM/JOIN item into a caller-supplied list. Subqueries are
     * rewritten with a fresh {@link SelectVisitor}, so they never touch the caller's list.
     */
    class FromItemVisitor extends FromItemVisitorAdapter {
        private final List<Pair<String, String>> tables;

        FromItemVisitor(List<Pair<String, String>> tables) {
            this.tables = tables;
        }

        @Override
        public void visit(Table table) {
            // Upper-cased and stripped of MySQL backticks (presumably to match the ACL map keys).
            String name = table.getName().toUpperCase().replace("`", "");
            // Alias.getName() is the bare alias; Alias.toString() would include a leading " AS ".
            String alias = table.getAlias() == null ? null : table.getAlias().getName();
            tables.add(new Pair<>(name, alias));
        }

        @Override
        public void visit(SubSelect ss) {
            ss.getSelectBody().accept(new SelectVisitor());
        }

        @Override
        public void visit(SubJoin sj) {
            sj.getLeft().accept(this);
            sj.getJoin().getRightItem().accept(this);
        }
    }

    /** Expression visitor that hands every subquery it encounters to the configured select visitor. */
    class ExpressionVisitor extends ExpressionVisitorAdapter {
        @Override
        public void visit(SubSelect subSelect) {
            if (getSelectVisitor() != null) {
                subSelect.getSelectBody().accept(getSelectVisitor());
            }
        }
    }

    /** Select-list visitor that scans each selected expression for subqueries. */
    static class SelectItemVisitor extends SelectItemVisitorAdapter {
        ExpressionVisitor expressionVisitor;

        public void setExpressionVisitor(ExpressionVisitor expressionVisitor) {
            this.expressionVisitor = expressionVisitor;
        }

        @Override
        public void visit(SelectExpressionItem item) {
            item.getExpression().accept(expressionVisitor);
        }
    }
}
